library(MDEI)
library(MASS)
library(splines2)

## Function for making data ----
make.data <- function(n, k, pot.type, het = TRUE) {
  # Simulate one dataset for the MDEI simulation study.
  #
  # Args:
  #   n:        number of observations (at least 2).
  #   k:        number of covariates, i.e. columns of X (at least 2, since
  #             X[, 1] and X[, 2] are used).
  #   pot.type: single value in 1:5 selecting the treatment-effect / outcome
  #             design (see the `switch` below).
  #   het:      if TRUE (default) the noise sd varies with X[, 2]; if FALSE the
  #             noise is homoskedastic standard normal before rescaling.
  #
  # Returns:
  #   A list with elements
  #     obs       - simulated outcome (fits.true plus noise),
  #     treat     - continuous treatment,
  #     X         - n x k covariate matrix,
  #     mce.true  - derivative of the treatment effect with respect to treat,
  #     fits.true - noiseless outcome (treatment effect plus baseline).
  stopifnot(
    length(pot.type) == 1, pot.type %in% 1:5,
    length(n) == 1, n >= 2,
    length(k) == 1, k >= 2
  )

  ## Generate data ----
  # Equicorrelated covariates: unit variances, pairwise covariance 0.5.
  var.mat <- diag(k)
  var.mat[var.mat == 0] <- 0.5

  X <- mvrnorm(n, rep(0, k), Sigma = var.mat)

  sign1 <- sign(X[, 1])
  g.lin <- (X[, 2]^2 - 1) / 4

  # Treatment depends on X[, 2]; in design 5 its mean also flips with the sign
  # of X[, 1]. The first draw is kept even for design 5 so that the RNG stream,
  # and therefore seeded simulation results, are unchanged.
  treat <- g.lin + rnorm(n, sd = 1)
  if (pot.type == 5) {
    treat <- g.lin * sign1 + rnorm(n, sd = 1)
  }

  # te: treatment effect; mce: its derivative in treat; rmy: baseline surface.
  design <- switch(
    pot.type,
    list(
      te = treat,
      mce = rep(1, n),
      rmy = X[, 1] + (X[, 2] - 1) / 4
    ),
    list(
      te = treat,
      mce = rep(1, n),
      rmy = X[, 1] + (X[, 2] - 1)^2 / 4
    ),
    list(
      te = 4 * sin(treat),
      mce = 4 * cos(treat),
      rmy = X[, 1] + (X[, 2] - 1)^2 / 4
    ),
    list(
      te = 4 * sin(treat) * X[, 1],
      mce = 4 * cos(treat) * X[, 1],
      rmy = (X[, 2] - 1)^2 / 4
    ),
    list(
      te = 4 * treat^2 * sign1,
      mce = 8 * treat * sign1,
      rmy = (X[, 2] - 1)^2 / 4
    )
  )

  fits.true <- design$te + design$rmy

  # Heteroskedastic noise: sd = 1 / sqrt(X[, 2]^2 + 1). This draw is made even
  # when het = FALSE so seeded results match the original RNG stream.
  errs <- rnorm(n, 0, (X[, 2]^2 + 1)^-0.5)
  if (!het) {
    errs <- rnorm(n)
  }
  # Rescale so the noise has the same sample sd as the signal.
  errs <- errs / sd(errs) * sd(fits.true)
  obs <- fits.true + errs

  list(
    "obs" = obs,
    "treat" = treat,
    "X" = X,
    "mce.true" = as.vector(design$mce),
    "fits.true" = as.vector(fits.true)
  )
}


##  Plotting function ----
plotrow50k <- function(m1, md.ex, ylim.fit = c(-102, 87),
                       ylim.width = c(0, 70)) {
  # Draw one row of three diagnostic panels comparing an MDEI fit with the
  # simulation truth:
  #   1. truth, interval estimates and point estimates against treatment,
  #      restricted to observations with X[, 1] > 0;
  #   2. absolute approximation error |mce.true - tau.est| with a lowess trend;
  #   3. interval width with a lowess trend.
  #
  # Args:
  #   m1:         list as returned by make.data() (uses X, treat, mce.true).
  #   md.ex:      MDEI fit; uses tau.est and CIs.tau. CIs.tau is assumed to be
  #               an n x 2 matrix of (lower, upper) bounds -- TODO confirm
  #               against the MDEI documentation.
  #   ylim.fit:   y-axis limits for panel 1 (default is the original
  #               hard-coded value).
  #   ylim.width: y-axis limits for panel 3 (default is the original
  #               hard-coded value).
  #
  # Returns NULL invisibly; called for its plotting side effects.
  treat <- m1$treat
  truth <- m1$mce.true
  est <- md.ex$tau.est
  lower <- md.ex$CIs.tau[, 1]
  upper <- md.ex$CIs.tau[, 2]

  s1 <- m1$X[, 1] > 0
  # The truth is covered when it lies strictly inside the interval, i.e. its
  # signed distances to the two bounds have opposite signs.
  cover <- (lower - truth) * (upper - truth) < 0

  # Panel 1: truth curve, intervals coloured by coverage, point estimates.
  plot(
    treat[s1],
    truth[s1],
    type = "n",
    xlim = range(treat),
    xlab = "",
    ylab = "",
    cex.axis = 2,
    ylim = ylim.fit
  )
  # lines() joins points in data order, so sort by treatment first.
  ord <- order(treat[s1])
  lines(treat[s1][ord], truth[s1][ord])
  segments(
    x0 = treat[s1],
    x1 = treat[s1],
    y0 = lower[s1],
    y1 = upper[s1],
    col = ifelse(cover[s1], gray(0), gray(.6))
  )
  points(treat[s1], est[s1], pch = 19, cex = .5)
  mtext("Treatment", side = 1, line = 3, cex = 2)
  mtext("Fitted Value", side = 2, line = 3, cex = 2)
  legend(
    "bottomright",
    legend = c("Estimate", "Covers Truth", "Does Not Cover Truth"),
    lty = c(0, 1, 1),
    col = c("black", "black", gray(.6)),
    # NA (not a negative value) is the documented way to omit a symbol.
    pch = c(19, NA, NA),
    bty = "n",
    cex = 2
  )

  # Lowess span for the trend lines; narrower for the 50,000-sample run.
  band.use <- if (length(treat) == 50000) .025 else .1

  # Scatter y against treatment with a dashed grey lowess trend.
  trend_panel <- function(y, ylab, ylim = NULL) {
    plot(
      treat,
      y,
      pch = 19,
      cex = .2,
      xlab = "",
      ylab = "",
      ylim = ylim,
      cex.axis = 2
    )
    lines(lowess(treat, y, f = band.use), lwd = 3, lty = 2, col = gray(.5))
    mtext("Treatment", side = 1, line = 3, cex = 2)
    mtext(ylab, side = 2, line = 3, cex = 2)
  }

  # Panel 2: approximation error.
  trend_panel(abs(truth - est), "Approximation Error")
  # Panel 3: interval width.
  trend_panel(upper - lower, "Width of Band", ylim = ylim.width)

  invisible(NULL)
}

##  Run with n=5000 ----
# The quartz device exists only on macOS builds; request it only where it is
# available so the script also runs elsewhere (the figure is written via pdf()).
if (isTRUE(capabilities("aqua"))) {
  options(device = "quartz")
}
set.seed(1)
n <- 5000
k <- 5
pot.type <- 5
m1 <- data.ex <- make.data(n, k, pot.type)
# Re-seed so the MDEI fit (which presumably randomises its sample splits) is
# reproducible independently of how many draws make.data() consumed.
set.seed(1)
md.ex <-
  MDEI(
    data.ex$obs,
    data.ex$treat,
    data.ex$X,
    samplesplit = TRUE,
    conformal = TRUE,
    splits = 50
  )

##  Run with n=50000 ----
# Same design as the n = 5,000 run, with ten times the data and fewer splits.
n <- 50000
k <- 5
pot.type <- 5
set.seed(1)
data.ex <- make.data(n, k, pot.type)
m2 <- data.ex
# Re-seed before fitting so the MDEI fit is reproducible on its own.
set.seed(1)
md.ex2 <- MDEI(
  data.ex$obs,
  data.ex$treat,
  data.ex$X,
  samplesplit = TRUE,
  conformal = TRUE,
  splits = 10
)


## Plot results ----
# Two rows of three panels: row 1 is the n = 5,000 run, row 2 the n = 50,000 run.
# Arguments are spelled out in full rather than relying on partial matching.
pdf("Sim2_50kplot.pdf", width = 12 * 2, height = 8 * 2)
par(
  mfrow = c(2, 3),
  oma = c(0, 4, 2.5, 0),
  mar = c(5, 5, .2, 1.5)
)
plotrow50k(m1, md.ex)
plotrow50k(m2, md.ex2)
# Row labels in the left outer margin; y = 0.775 is the top row (m1, n = 5,000)
# and y = 0.275 the bottom row (m2, n = 50,000).
mtext(
  c("n = 50,000", "n = 5,000"),
  side = 2,
  at = c(.275, .775),
  line = 1,
  outer = TRUE,
  cex = 2
)
# Column titles in the top outer margin.
mtext(
  c("Coverage", "Approximation Error", "CI Width"),
  side = 3,
  at = c(.15, .5, .85),
  outer = TRUE,
  line = 0,
  cex = 2.5,
  font = 2
)

dev.off()
